from pathlib import Path
from typing import Any, Callable, Dict, Tuple, Union

import mne
import numpy as np

from .base_dataset import BaseDataset
from ...utils import get_random_dir_path


def default_read_fn(file_path, **kwargs):
    # Load EEG file
    raw = mne.io.read_raw(file_path)
    # Convert raw to epochs
    epochs = mne.make_fixed_length_epochs(raw, duration=1)
    # Return EEG data
    return epochs


class FolderDataset(BaseDataset):
    '''
    Read EEG samples and their corresponding labels from a fixed folder structure. This class allows two kinds of common file structures, :obj:`subject_in_label` and :obj:`label_in_subject`. Here, :obj:`subject_in_label` corresponds to the following file structure:

    .. code-block:: python

        tree
        # outputs
        label01
        |- sub01.edf
        |- sub02.edf
        label02
        |- sub01.edf
        |- sub02.edf

    And :obj:`label_in_subject` corresponds to the following file structure:

    .. code-block:: python

        tree
        # outputs
        sub01
        |- label01.edf
        |- label02.edf
        sub02
        |- label01.edf
        |- label02.edf

    An example dataset for GNN-based methods:

    .. code-block:: python

        from torcheeg.datasets import FolderDataset
        from torcheeg import transforms

        sfreq = 128  # Sampling rate
        n_channels = 14  # Number of channels
        duration = 5  # Data collected for 5 seconds
        for i in range(num_files):
            n_samples = sfreq * duration
            data = np.random.randn(n_channels, n_samples)

            ch_names = [f'ch_{i+1:03}' for i in range(n_channels)]
            ch_types = ['eeg'] * n_channels
            info = mne.create_info(ch_names, sfreq, ch_types)
            raw = mne.io.RawArray(data, info)

            file_name = f'sub{i+1}.fif'
            file_path = os.path.join('./root_folder/', file_name)
            raw.save(file_path)

        label_map = {'folder1': 0, 'folder2': 1}
        dataset = FolderDataset(root_path='./root_folder',
                                structure='subject_in_label',
                                num_channel=14,
                                online_transform=transforms.ToTensor(),
                                label_transform=transforms.Compose([
                                    transforms.Select('label'),
                                    transforms.Lambda(lambda x: label_map[x])
                                ]),
                                num_worker=4)

    Args:
        root_path (str): The path to the root folder. (default: :obj:`'./folder'`)
        structure (str): Folder structure, which affects how labels and subjects are mapped to EEG signal samples. Please refer to the above description of the structure of the two folders to select the correct parameters. (default: :obj:`'subject_in_label'`)
        read_fn (Callable): Method for reading files in a folder. By default, this class provides methods for reading files using :obj:`mne.io.read_raw`. At the same time, we allow users to pass in custom file reading methods. The first input parameter of whose is file_path, and other parameters are additional parameters passed in when the class is initialized. For example, you can pass :obj:`chunk_size=32` to :obj:`FolderDataset`, then :obj:`chunk_size` will be received here.
        online_transform (Callable, optional): The transformation of the EEG signals and baseline EEG signals. The input is a :obj:`np.ndarray`, and the ouput is used as the first and second value of each element in the dataset. (default: :obj:`None`)
        offline_transform (Callable, optional): The usage is the same as :obj:`online_transform`, but executed before generating IO intermediate results. (default: :obj:`None`)
        label_transform (Callable, optional): The transformation of the label. The input is an information dictionary, and the ouput is used as the third value of each element in the dataset. (default: :obj:`None`)
        io_path (str): The path to generated unified data IO, cached as an intermediate result. If set to None, a random path will be generated. (default: :obj:`None`)
        io_size (int): Maximum size database may grow to; used to size the memory mapping. If database grows larger than ``map_size``, an exception will be raised and the user must close and reopen. (default: :obj:`1048576`)
        io_mode (str): Storage mode of EEG signal. When io_mode is set to :obj:`lmdb`, TorchEEG provides an efficient database (LMDB) for storing EEG signals. LMDB may not perform well on limited operating systems, where a file system based EEG signal storage is also provided. When io_mode is set to :obj:`pickle`, pickle-based persistence files are used. When io_mode is set to :obj:`memory`, memory are used. (default: :obj:`lmdb`)
        num_worker (int): Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: :obj:`0`)
        verbose (bool): Whether to display logs during processing, such as progress bars, etc. (default: :obj:`True`)
    '''

    def __init__(self,
                 root_path: str = './folder',
                 structure: str = 'subject_in_label',
                 read_fn: Union[None, Callable] = default_read_fn,
                 online_transform: Union[None, Callable] = None,
                 offline_transform: Union[None, Callable] = None,
                 label_transform: Union[None, Callable] = None,
                 io_path: Union[None, str] = None,
                 io_size: int = 1048576,
                 io_mode: str = 'lmdb',
                 num_worker: int = 0,
                 verbose: bool = True,
                 **kwargs):
        if io_path is None:
            io_path = get_random_dir_path(dir_prefix='datasets')

        # pass all arguments to super class
        params = {
            'root_path': root_path,
            'structure': structure,
            'read_fn': read_fn,
            'online_transform': online_transform,
            'offline_transform': offline_transform,
            'label_transform': label_transform,
            'io_path': io_path,
            'io_size': io_size,
            'io_mode': io_mode,
            'num_worker': num_worker,
            'verbose': verbose
        }
        params.update(kwargs)
        super().__init__(**params)
        # save all arguments to __dict__
        self.__dict__.update(params)

    @staticmethod
    def process_record(file: Any = None,
                       offline_transform: Union[None, Callable] = None,
                       read_fn: Union[None, Callable] = None,
                       **kwargs):

        file_path, subject_id, label = file

        trial_samples = read_fn(file_path, **kwargs)
        events = [i[0] for i in trial_samples.events]
        events.append(
            events[-1] +
            np.diff(events)[0])  # time interval between all events are same

        write_pointer = 0
        for i, trial_signal in enumerate(trial_samples.get_data()):
            t_eeg = trial_signal
            if not offline_transform is None:
                t = offline_transform(eeg=trial_signal)
                t_eeg = t['eeg']

            clip_id = f'{subject_id}_{label}_{write_pointer}'
            write_pointer += 1

            record_info = {
                'subject_id': subject_id,
                'start_at': events[i],
                'end_at': events[i + 1],
                'clip_id': clip_id,
                'label': label
            }

            yield {'eeg': t_eeg, 'key': clip_id, 'info': record_info}

    def set_records(self,
                    root_path: str = './folder',
                    structure: str = 'subject_in_label',
                    **kwargs):
        # get all the subfolders
        subfolders = [str(i) for i in Path(root_path).iterdir() if i.is_dir()]
        # get all the files in the subfolders
        file_path_list = []
        for subfolder in subfolders:
            file_path_list += [
                str(i) for i in Path(subfolder).iterdir() if i.is_file()
            ]
        # get the subject id
        if structure == 'subject_in_label':
            # get the file name without the extension
            subjects = [i.split('/')[-1].split('.')[0] for i in file_path_list]
            labels = [i.split('/')[-2] for i in file_path_list]
        elif structure == 'label_in_subject':
            subjects = [i.split('/')[-2] for i in file_path_list]
            labels = [i.split('/')[-1].split('.')[0] for i in file_path_list]
        else:
            raise ValueError('Unknown folder mode: {}'.format(structure))

        file_path_subject_label = list(zip(file_path_list, subjects, labels))
        return file_path_subject_label

    def __getitem__(self, index: int) -> Tuple[any, any, int, int, int]:
        info = self.read_info(index)

        eeg_index = str(info['clip_id'])
        eeg_record = str(info['_record_id'])
        eeg = self.read_eeg(eeg_record, eeg_index)

        signal = eeg
        label = info

        if self.online_transform:
            signal = self.online_transform(eeg=eeg)['eeg']

        if self.label_transform:
            label = self.label_transform(y=info)['y']

        return signal, label

    @property
    def repr_body(self) -> Dict:
        return dict(
            super().repr_body, **{
                'root_path': self.root_path,
                'chunk_size': self.chunk_size,
                'overlap': self.overlap,
                'num_channel': self.num_channel,
                'online_transform': self.online_transform,
                'offline_transform': self.offline_transform,
                'label_transform': self.label_transform,
                'before_trial': self.before_trial,
                'after_trial': self.after_trial,
                'num_worker': self.num_worker,
                'verbose': self.verbose,
                'io_size': self.io_size
            })
